#' Simulate the cells
#' 
#' [generate_cells()] runs the stochastic simulations of the cells and, when a gold
#' standard is available, maps each simulated state onto it.
#' [simulation_default()] is used to configure the parameters pertaining to this process.
#' 
#' @param model A dyngen intermediary model for which the gold standard has been generated with [generate_gold_standard()].
#' @param burn_time The burn in time of the system, used to determine an initial state vector. If `NULL`, the burn time will be inferred from the backbone.
#' @param total_time The total simulation time of the system. If `NULL`, the simulation time will be inferred from the backbone.
#' @param ssa_algorithm Which SSA algorithm to use for simulating the cells with [GillespieSSA2::ssa()]
#' @param census_interval The time interval between two recorded states of the simulation (passed as `census_interval` to [GillespieSSA2::ssa()]).
#' @param store_reaction_firings Whether or not to store the number of reaction firings.
#' @param store_reaction_propensities Whether or not to store the propensity values of the reactions.
#' @param compute_cellwise_grn Whether or not to compute the cellwise GRN activation values.
#' @param compute_dimred Whether to perform a dimensionality reduction after simulation.
#' @param compute_rna_velocity Whether or not to compute the RNA velocity (derived from the reaction propensities) after simulation.
#' @param experiment_params A tibble generated by rbinding multiple calls of [simulation_type_wild_type()] and [simulation_type_knockdown()].
#' @param kinetics_noise_function A function that adds noise to the kinetics of each simulation.
#'   It takes the `feature_info` and `feature_network` as input parameters,
#'   modifies them, and returns them as a list. See [kinetics_noise_none()] and [kinetics_noise_simple()].
#' 
#' @return A dyngen model.
#' 
#' @seealso [dyngen] on how to run a complete dyngen simulation
#' 
#' @importFrom GillespieSSA2 ssa
#' @export
#' 
#' @examples
#' library(dplyr)
#' model <- 
#'   initialise_model(
#'     backbone = backbone_bifurcating(),
#'     simulation = simulation_default(
#'       ssa_algorithm = ssa_etl(tau = .1),
#'       experiment_params = bind_rows(
#'         simulation_type_wild_type(num_simulations = 4),
#'         simulation_type_knockdown(num_simulations = 4)
#'       )
#'     )
#'   )
#' \donttest{
#' data("example_model")
#' model <- example_model %>% generate_cells()
#'   
#' plot_simulations(model)
#' plot_gold_mappings(model)
#' plot_simulation_expression(model)
#' }
generate_cells <- function(model) {
  # column name used in tidy evaluation below (keeps R CMD check quiet)
  time <- NULL

  # compile the propensity functions once; every simulation reuses them
  if (model$verbose) cat("Precompiling reactions for simulations\n")
  model <- .add_timing(model, "6_simulations", "precompile reactions for simulations")
  reactions <- .generate_cells_precompile_reactions(model)

  # one independent simulation per row of the experiment parameters
  num_simulations <- nrow(model$simulation_params$experiment_params)
  if (model$verbose) cat("Running ", num_simulations, " simulations\n", sep = "")
  model <- .add_timing(model, "6_simulations", "running simulations")
  simulations <- pbapply::pblapply(
    X = seq_len(nrow(model$simulation_params$experiment_params)),
    FUN = .generate_cells_simulate_cell,
    cl = model$num_cores,
    model = model,
    reactions = reactions
  )

  # combine the per-simulation results by stacking their rows
  model <- .add_timing(model, "6_simulations", "generate output")
  stack_rows <- function(field) {
    do.call(rbind, map(simulations, field))
  }
  model$simulations <- lst(
    meta = map_df(simulations, "meta"),
    counts = stack_rows("counts"),
    cellwise_grn = stack_rows("cellwise_grn"),
    reaction_firings = stack_rows("reaction_firings"),
    reaction_propensities = stack_rows("reaction_propensities"),
    rna_velocity = stack_rows("rna_velocity"),
    kd_multiplier = stack_rows("kd_multiplier"),
    perturbed_parameters = stack_rows("perturbed_parameters")
  )

  # map every simulated state to the gold standard, if there is one;
  # otherwise only rename the raw simulation time
  if (model$verbose) cat("Mapping simulations to gold standard\n", sep = "")
  model <- .add_timing(model, "6_simulations", "map simulations to gold standard")
  model$simulations$meta <-
    if (is.null(model[["gold_standard"]])) {
      model$simulations$meta %>% rename(sim_time = time)
    } else {
      .generate_cells_predict_state(model)
    }

  # optional dimensionality reduction
  model <- .add_timing(model, "6_simulations", "perform dimred")
  if (model$simulation_params$compute_dimred) {
    if (model$verbose) cat("Performing dimred\n", sep = "")
    model <- calculate_dimred(model)
  }

  .add_timing(model, "6_simulations", "end")
}

#' @export
#' @rdname generate_cells
#' @importFrom GillespieSSA2 ssa_etl
simulation_default <- function(
  burn_time = NULL,
  total_time = NULL,
  ssa_algorithm = ssa_etl(tau = 30 / 3600),
  census_interval = 4,
  experiment_params = bind_rows(
    simulation_type_wild_type(num_simulations = 32),
    simulation_type_knockdown(num_simulations = 0)
  ),
  store_reaction_firings = FALSE,
  store_reaction_propensities = FALSE,
  compute_cellwise_grn = FALSE,
  compute_dimred = TRUE,
  compute_rna_velocity = FALSE,
  kinetics_noise_function = kinetics_noise_simple(mean = 1, sd = .005)
) {
  # Bundle the settings into a named list with one element per argument.
  # Elements are evaluated left to right and NULL values (e.g. burn_time)
  # are kept as named NULL entries.
  settings <- list(
    burn_time = burn_time,
    total_time = total_time,
    ssa_algorithm = ssa_algorithm,
    census_interval = census_interval,
    experiment_params = experiment_params,
    store_reaction_firings = store_reaction_firings,
    store_reaction_propensities = store_reaction_propensities,
    compute_cellwise_grn = compute_cellwise_grn,
    compute_dimred = compute_dimred,
    compute_rna_velocity = compute_rna_velocity,
    kinetics_noise_function = kinetics_noise_function
  )
  settings
}

#' @importFrom GillespieSSA2 compile_reactions
.generate_cells_precompile_reactions <- function(model) {
  # Compile the reactions of the simulation system once, so that every
  # simulation can reuse the compiled propensity functions. Parameters are not
  # hardcoded, so they can be swapped per simulation (e.g. for knockdowns).
  sim_system <- model$simulation_system
  initial_state <- sim_system$initial_state

  # a matrix holds one initial state per simulation (molecules as columns);
  # otherwise the initial state is a single named vector
  state_ids <-
    if (is.matrix(initial_state)) {
      colnames(initial_state)
    } else {
      names(initial_state)
    }

  GillespieSSA2::compile_reactions(
    reactions = sim_system$reactions,
    buffer_ids = unique(unlist(map(sim_system$reactions, "buffer_ids"))),
    state_ids = state_ids,
    params = sim_system$parameters,
    hardcode_params = FALSE,
    fun_by = 1000L
  )
}

#' Add small noise to the kinetics of each simulation
#' 
#' @param mean The mean level of noise (should be 1)
#' @param sd The sd of the noise (should be a relatively small value)
#' 
#' @return A function that takes `feature_info` and `feature_network` and returns them (possibly perturbed) as a named list.
#' 
#' @rdname kinetics_noise
#' @export
kinetics_noise_none <- function() {
  # The generated noise function leaves the kinetics untouched and simply
  # bundles them in a named list.
  no_noise <- function(feature_info, feature_network) {
    list(
      feature_info = feature_info,
      feature_network = feature_network
    )
  }
  no_noise
}

#' @rdname kinetics_noise
#' @export
kinetics_noise_simple <- function(mean = 1, sd = .005) {
  # satisfy r cmd check
  mrna_halflife <- protein_halflife <- NULL

  # multiply each element of x by an independent N(mean, sd) noise factor
  perturb <- function(x) x * rnorm(length(x), mean = mean, sd = sd)

  function(feature_info, feature_network) {
    rate_columns <- c(
      "transcription_rate", "splicing_rate", "translation_rate",
      "mrna_halflife", "protein_halflife"
    )

    # basal expression is capped at 1; decay rates are recomputed from the
    # perturbed half-lives so that both stay consistent
    noisy_info <-
      feature_info %>%
      mutate_at("basal", ~ pmin(1, perturb(.))) %>%
      mutate_at(rate_columns, ~ perturb(.)) %>%
      mutate(
        mrna_decay_rate = log(2) / mrna_halflife,
        protein_decay_rate = log(2) / protein_halflife
      )

    noisy_network <-
      feature_network %>%
      mutate_at(c("strength", "hill"), ~ perturb(.))

    list(
      feature_info = noisy_info,
      feature_network = noisy_network
    )
  }
}

# Run the simulation described by row `simulation_i` of the experiment parameters.
#
# The kinetics are first perturbed with `kinetics_noise_function` (after seeding
# the RNG with the experiment's seed) to mimic cell-to-cell variation. If
# `burn_time > 0`, a burn-in run is performed in which only the molecules listed
# in `burn_variables` can change; its last census state is the starting state of
# the main run. For knockdown experiments, the main run stops at
# `timepoint * total_time`, after which the transcription rates of a random
# subset of the candidate genes are multiplied by `multiplier` and the
# simulation continues until `total_time`.
#
# Returns a list with `meta` (simulation_i and time per census row; burn-in rows
# have non-positive times), `counts`, `cellwise_grn`, `reaction_firings`,
# `reaction_propensities`, `rna_velocity`, `kd_multiplier` and
# `perturbed_parameters`. With `debug = TRUE`, the main simulator is returned
# (before it is run) instead.
.generate_cells_simulate_cell <- function(simulation_i, model, reactions, verbose = FALSE, debug = FALSE) {
  # satisfy r cmd check
  time <- NULL
  
  sim_params <- model$simulation_params
  sim_system <- model$simulation_system
  expr_params <- sim_params$experiment_params %>% extract_row_to_list(simulation_i)
  
  set.seed(expr_params$seed)
  
  # randomise kinetics to simulate intercellular differences
  out <- sim_params$kinetics_noise_function(model$feature_info, model$feature_network)
  out2 <- .kinetics_calculate_dissociation(out$feature_info, out$feature_network)
  feature_info <- out2$feature_info
  feature_network <- out2$feature_network
  perturbed_parameters <- .kinetics_extract_parameters(feature_info, feature_network)
  
  # get initial state (a matrix holds one initial state per simulation)
  initial_state <- sim_system$initial_state
  if (is.matrix(initial_state)) {
    initial_state <- initial_state[simulation_i, ]
  }
  
  # propensities are needed both for storing them and for computing RNA velocity
  log_propensity <- sim_params$store_reaction_propensities || sim_params$compute_rna_velocity
  
  if (sim_params$burn_time > 0) {
    # copy of the compiled reactions in which molecules that are not burn-in
    # variables have no state change, so they stay fixed during the burn-in
    burn_reactions <- reactions
    rem <- setdiff(sim_system$molecule_ids, sim_system$burn_variables)
    if (length(rem) > 0) {
      burn_reactions$state_change[match(rem, sim_system$molecule_ids), ] <- 0
    }
    
    # burn in
    out <- GillespieSSA2::ssa(
      initial_state = initial_state,
      reactions = burn_reactions,
      final_time = sim_params$burn_time,
      census_interval = sim_params$census_interval,
      params = perturbed_parameters,
      method = sim_params$ssa_algorithm,
      stop_on_neg_state = FALSE,
      verbose = verbose,
      log_buffer = FALSE,
      log_firings = sim_params$store_reaction_firings,
      log_propensity = log_propensity
    )
    
    # shift the burn-in times so that the burn-in ends at time 0
    burn_meta <- 
      tibble(
        time = c(head(out$time, -1), sim_params$burn_time)
      ) %>%
      mutate(time = time - max(time))
    burn_counts <- out$state %>% Matrix::Matrix(sparse = TRUE)
    
    burn_reaction_firings <- if (sim_params$store_reaction_firings) out$firings %>% Matrix::Matrix(sparse = TRUE) else NULL
    burn_reaction_propensities <- if (log_propensity) out$propensity %>% Matrix::Matrix(sparse = TRUE) else NULL
    
    new_initial_state <- out$state[nrow(out$state), ]
  } else {
    burn_meta <- NULL
    burn_counts <- NULL
    new_initial_state <- initial_state
    burn_reaction_firings <- NULL
    burn_reaction_propensities <- NULL
  }
  
  # a knockdown simulation first runs until `timepoint` (relative) of the total time
  total_time <- sim_params$total_time
  if (expr_params$type == "knockdown") {
    total_time <- total_time * expr_params$timepoint
  }
  
  # actual simulation
  sim <- GillespieSSA2::ssa(
    initial_state = new_initial_state, 
    reactions = reactions,
    final_time = total_time,
    census_interval = sim_params$census_interval,
    params = perturbed_parameters,
    method = sim_params$ssa_algorithm,
    stop_on_neg_state = FALSE,
    verbose = verbose,
    log_buffer = FALSE,
    log_firings = sim_params$store_reaction_firings,
    log_propensity = log_propensity,
    return_simulator = TRUE
  )
  if (debug) return(sim)
  
  # run simulation
  sim$run()
  
  # get output
  # based on GillespieSSA2:::return_output
  out <- list(
    time = sim$output_time,
    state = sim$output_state,
    propensity = sim$output_propensity,
    firings = sim$output_firings,
    buffer = sim$output_buffer,
    stats = sim$get_statistics()
  )
  # set names of objects
  colnames(out$state) <- names(new_initial_state)
  if (sim$log_propensity) {
    colnames(out$propensity) <- reactions$reaction_ids
  }
  if (sim$log_buffer) {
    colnames(out$buffer) <- reactions$buffer_ids
  }
  if (sim$log_firings) {
    colnames(out$firings) <- reactions$reaction_ids
  }
  
  # reformat output; the last census time is reported as the end time
  meta <- 
    tibble(
      time = c(head(out$time, -1), total_time)
    )
  counts <- out$state %>% Matrix::Matrix(sparse = TRUE)
  reaction_firings <- if (sim_params$store_reaction_firings) out$firings %>% Matrix::Matrix(sparse = TRUE) else NULL
  reaction_propensities <- if (log_propensity) out$propensity %>% Matrix::Matrix(sparse = TRUE) else NULL
  kd_state <- out$state[nrow(out$state), ]
  
  # simulation of knockdown, if any
  if (expr_params$type == "knockdown") {
    # parameters must not be hardcoded in the compiled reactions, otherwise
    # the modified transcription rates below would be ignored
    assert_that(!reactions$hardcode_params)
    
    kd_total_time <- sim_params$total_time - total_time
    
    kd_gene_candidates <- expr_params$genes
    if (length(kd_gene_candidates) == 1 && kd_gene_candidates == "*") {
      kd_gene_candidates <- model$feature_info$feature_id
    }
    # Never request more genes than there are candidates (sample() errors in
    # that case), and index explicitly so a length-one candidate vector is
    # never treated as a population size. For character candidates this draws
    # exactly the same random numbers as sample(kd_gene_candidates, n).
    num_kd_genes <- min(expr_params$num_genes, length(kd_gene_candidates))
    kd_genes <- kd_gene_candidates[sample.int(length(kd_gene_candidates), num_kd_genes)]
    kd_wprs <- paste0("transcription_rate_", kd_genes)
    
    kd_params <- perturbed_parameters
    kd_params[kd_wprs] <- kd_params[kd_wprs] * expr_params$multiplier
    
    kd_multiplier <- tibble(simulation_i, gene = kd_genes, multiplier = expr_params$multiplier)
    
    # continue from the last state of the main run with the knockdown parameters
    out <- GillespieSSA2::ssa(
      initial_state = kd_state, 
      reactions = reactions,
      final_time = kd_total_time,
      census_interval = sim_params$census_interval,
      params = kd_params,
      method = sim_params$ssa_algorithm,
      stop_on_neg_state = FALSE,
      verbose = verbose,
      log_buffer = FALSE,
      log_firings = sim_params$store_reaction_firings,
      log_propensity = log_propensity
    )
    
    kd_meta <- 
      tibble(
        time = total_time + c(head(out$time, -1), kd_total_time)
      )
    kd_counts <- out$state %>% Matrix::Matrix(sparse = TRUE)
    kd_reaction_firings <- if (sim_params$store_reaction_firings) out$firings %>% Matrix::Matrix(sparse = TRUE) else NULL
    kd_reaction_propensities <- if (log_propensity) out$propensity %>% Matrix::Matrix(sparse = TRUE) else NULL
  } else {
    kd_meta <- NULL
    kd_counts <- NULL
    kd_reaction_firings <- NULL
    kd_reaction_propensities <- NULL
    kd_multiplier <- NULL
  }
  
  # stitch burn-in, main and knockdown phases together, in time order
  if (!is.null(burn_meta)) {
    meta <- bind_rows(burn_meta, meta)
    counts <- rbind(burn_counts, counts)
    reaction_firings <- rbind(burn_reaction_firings, reaction_firings)
    reaction_propensities <- rbind(burn_reaction_propensities, reaction_propensities)
  }
  if (!is.null(kd_meta)) {
    meta <- bind_rows(meta, kd_meta)
    counts <- rbind(counts, kd_counts)
    reaction_firings <- rbind(reaction_firings, kd_reaction_firings)
    reaction_propensities <- rbind(reaction_propensities, kd_reaction_propensities)
  }
  
  cellwise_grn <- 
    if (sim_params$compute_cellwise_grn) {
      .generate_cells_compute_cellwise_grn(feature_info, feature_network, new_initial_state, reactions, counts, sim)
    } else {
      NULL
    }
  
  rna_velocity <-
    if (sim_params$compute_rna_velocity) {
      .generate_cells_compute_rna_velocity(model, reaction_propensities)
    } else {
      NULL
    }
  
  meta <- meta %>% mutate(simulation_i) %>% select(simulation_i, everything())
  
  # only store if specifically asked for
  if (!sim_params$store_reaction_propensities) {
    reaction_propensities <- NULL
  }
  
  lst(
    meta, counts, cellwise_grn, 
    reaction_firings, reaction_propensities, 
    rna_velocity,
    kd_multiplier, perturbed_parameters
  )
}

# Compute cellwise GRN activation values. For each census state (row of
# `counts`) and each regulatory interaction (row of `feature_network`), this
# measures how much the propensity of the target's transcription reaction drops
# when the regulator's protein count is set to 0. The drop is divided by the
# target's transcription rate.
#
# `sim` is the simulator object of the main run. It is reused here as a
# propensity calculator by overwriting `sim$state` and calling
# `sim$calculate_propensity()`. NOTE(review): this presumably relies on `sim`
# having reference semantics, so assignments to `sim$state` / `sim$propensity`
# persist across iterations and remain visible to the caller. Confirm against
# GillespieSSA2.
#
# Returns a sparse matrix with one row per row of `counts` and one column per
# interaction, with columns named "from->to". Only effects with an absolute
# value above 0.001 are stored.
.generate_cells_compute_cellwise_grn <- function(feature_info, feature_network, new_initial_state, reactions, counts, sim) {
  # satisfy r cmd check
  feature_id <- transcription_rate <- from <- to <- reg_y_match <- `.` <- j <- transcription_rate <- ko_effect <- NULL
  
  # per interaction: index of the regulator's protein in the state vector,
  # index of the target's transcription reaction, interaction index j
  # (its row in feature_network) and the target's transcription rate
  fn <- 
    feature_network %>% 
    left_join(feature_info %>% select(to = feature_id, transcription_rate), by = "to") %>% 
    transmute(
      reg_y_match = match(paste0("mol_protein_", from), names(new_initial_state)),
      tar_prop_match = match(paste0("transcription_", to), reactions$reaction_ids),
      j = row_number(),
      transcription_rate
    )
  
  ko_effects <- 
    map_df(
      seq_len(nrow(counts)),
      function(counti) {
        # load census state `counti` into the simulator and compute its propensities
        sim$state <- counts[counti,]
        sim$calculate_propensity()
        
        # df <- fn %>% filter(from == "Burn1_TF1")
        ko_effects <- 
          fn %>% 
          group_by(reg_y_match) %>% 
          do({
            # all interactions of a single regulator
            df <- .
            
            if (nrow(df) == 0) return(NULL) 
            
            orig_reg_state <- sim$state[[df$reg_y_match[[1]]]]
            
            # only knock out a regulator that is present; if its count is
            # already 0, setting it to 0 changes nothing and the effect is 0
            if (orig_reg_state != 0) {
              orig_tar_prop <- sim$propensity[df$tar_prop_match]
              sim$state[[df$reg_y_match[[1]]]] <- 0
              sim$calculate_propensity()
              ko_effect <- orig_tar_prop - sim$propensity[df$tar_prop_match]
            } else {
              ko_effect <- rep(0, nrow(df))
            }
            
            # normalise the propensity drop by the target's transcription rate
            out <- 
              df %>% 
              transmute(
                i = counti,
                j,
                ko_effect = ko_effect / transcription_rate
              )
            
            # restore the regulator's count and its targets' transcription
            # propensities before the next regulator is handled (other
            # propensities keep their knocked-out values until the next
            # sim$calculate_propensity() call)
            if (orig_reg_state != 0) {
              sim$state[[df$reg_y_match[[1]]]] <- orig_reg_state
              sim$propensity[df$tar_prop_match] <- orig_tar_prop
            }
            
            out
          }) %>% 
          ungroup()
      }
    )
  
  # drop negligible effects to keep the result sparse
  ko_effects <- ko_effects %>% filter(abs(ko_effect) > .001)
  
  Matrix::sparseMatrix(
    i = ko_effects$i,
    j = ko_effects$j,
    x = ko_effects$ko_effect,
    dims = c(nrow(counts), nrow(feature_network)),
    dimnames = list(
      NULL,
      paste0(feature_network$from, "->", feature_network$to)
    )
  )
}

# Map every simulated state onto its nearest gold standard state.
#
# Nearest neighbours are computed with `model$distance_metric` on the columns
# present in the (non burn-in) gold standard counts. This is done in chunks of
# at most `chunk_size` simulated states, to bound the size of each distance
# matrix. Each simulated state inherits the `from`/`to` edge and `time` of its
# match, after which the times are min-max scaled per edge. A warning is given
# when some gold standard edge is not matched by any simulated state.
#
# Returns the simulation meta data with columns simulation_i, sim_time (the raw
# simulation time), from, to and time.
.generate_cells_predict_state <- function(model, chunk_size = 10000L) {
  simulation_i <- time <- from <- to <- NULL
  
  # fetch gold standard data, excluding burn-in states
  gs_meta <- model$gold_standard$meta
  gold_ix <- !gs_meta$burn
  gs_meta <- gs_meta[gold_ix, , drop = FALSE]
  gs_counts <- model$gold_standard$counts[gold_ix, , drop = FALSE]
  
  # fetch simulation data, restricted to the molecules of the gold standard counts
  sim_meta <- model$simulations$meta
  sim_counts <- model$simulations$counts[, colnames(gs_counts), drop = FALSE]
  
  # index of the nearest gold standard state for each row of `counts`
  nearest_gold <- function(counts) {
    dis <- dynutils::calculate_distance(gs_counts, counts, method = model$distance_metric)
    apply(dis, 2, which.min)
  }
  
  # calculate 1NN; chunking avoids one huge distance matrix
  if (nrow(sim_counts) > chunk_size) {
    start <- seq(1, nrow(sim_counts), by = chunk_size)
    stop <- pmin(start + chunk_size - 1, nrow(sim_counts))
    best_matches <- pmap(lst(start, stop), function(start, stop) {
      nearest_gold(sim_counts[start:stop, , drop = FALSE])
    })
    best_match <- unlist(best_matches)
  } else {
    best_match <- nearest_gold(sim_counts)
  }
  
  # add predictions to sim_meta
  sim_meta <- 
    bind_cols(
      sim_meta %>% select(simulation_i, sim_time = time),
      gs_meta[best_match, , drop = FALSE] %>% select(from, to, time)
    ) %>% 
    group_by(from, to) %>% 
    mutate(time = dynutils::scale_minmax(time)) %>% 
    ungroup()
  
  # check if all gold standard edges are reached by at least one simulated state;
  # count() returns an ungrouped result without summarise()'s regrouping message
  sim_edges <- sim_meta %>% count(from, to)
  network <- full_join(model$gold_standard$network, sim_edges, by = c("from", "to"))
  
  if (any(is.na(network$n))) {
    warning("Simulation does not contain all gold standard edges. This simulation likely suffers from bad kinetics; choose a different seed and rerun.")
  }
  
  # return output
  sim_meta
}

# Compute the RNA velocity of every gene at every census state.
#
# For each gene, the velocity is the sum, over all reactions that change the
# gene's mRNA or pre-mRNA count, of the reaction's effect on that molecule
# multiplied by the reaction's propensity. Missing (NA) values are set to 0 and
# explicit zeros are dropped.
#
# `reaction_propensities` is expected to have one column per reaction, in the
# same order as `model$simulation_system$reactions`. NOTE(review): the column
# order is assumed here rather than checked; confirm against the caller.
#
# Returns a sparse matrix with one row per census state and one column per gene
# (named by feature id).
.generate_cells_compute_rna_velocity <- function(model, reaction_propensities) {
  # satisfy r cmd check
  feature_id <- NULL
  
  feature_info <- model$feature_info
  sim_system <- model$simulation_system
  
  # map both the mRNA and the pre-mRNA molecule of each gene to its feature id
  # (not called `map`, which would mask purrr::map)
  feature_ids <- feature_info$feature_id
  relevant_molecules <- c(feature_info$mol_mrna, feature_info$mol_premrna)
  molecule_to_feature <- set_names(c(feature_ids, feature_ids), relevant_molecules)
  
  # for each reaction, its effect on every relevant (pre-)mRNA molecule
  reaction_effects <- map_df(
    seq_along(sim_system$reactions),
    function(reaction_ix) {
      reaction <- sim_system$reactions[[reaction_ix]]
      tibble(
        reaction_ix = reaction_ix,
        feature_id = names(reaction$effect),
        effect = reaction$effect,
        name = reaction$name
      ) %>% filter(feature_id %in% relevant_molecules)
    }) %>%
    mutate(
      row = row_number(),
      feature_id = molecule_to_feature[feature_id],
      molecule_ix = match(feature_id, feature_ids)
    )
  
  # propensities of the relevant reactions; drop = FALSE keeps a matrix even
  # when there is a single relevant reaction or a single census state
  propensities <- reaction_propensities[, reaction_effects$reaction_ix, drop = FALSE]
  
  # weight each reaction's propensity by its (signed) effect and sum per gene
  perreaction_to_pergene <- with(
    reaction_effects,
    Matrix::sparseMatrix(i = row, j = molecule_ix, x = effect, dims = c(nrow(reaction_effects), length(feature_ids)))
  )
  rna_velocity <- propensities %*% perreaction_to_pergene
  rna_velocity@x[is.na(rna_velocity@x)] <- 0
  rna_velocity <- Matrix::drop0(rna_velocity)
  colnames(rna_velocity) <- feature_ids
  
  rna_velocity
}

#' @param num_simulations The number of simulations to run.
#' @param seed A set of seeds, one for each of the simulations. A single seed is recycled.
#' @rdname generate_cells
#' @export
simulation_type_wild_type <- function(num_simulations, seed = sample.int(10 * num_simulations, num_simulations)) {
  if (num_simulations == 0) {
    NULL
  } else {
    # Repeat `type` so that the result always has `num_simulations` rows.
    # Otherwise a single seed would silently yield a single simulation.
    # This matches simulation_type_knockdown(), where a length-one seed is
    # recycled. A seed vector of any other length is an error in tibble().
    tibble(
      type = rep("wild_type", num_simulations),
      seed
    )
  }
}

#' @param timepoint The relative time point of the knockdown, as a fraction of the total simulation time.
#' @param genes Which genes to sample from. `"*"` for all genes. Either a character vector
#'   of candidate genes shared by all simulations, or a list with one such vector per simulation.
#' @param num_genes The number of genes to knockdown.
#' @param multiplier The strength of the knockdown. Use 0 for a full knockout, 0<x<1 for a knockdown, and >1 for an overexpression.
#' 
#' @rdname generate_cells
#' @export
simulation_type_knockdown <- function(
  num_simulations, 
  timepoint = runif(num_simulations),
  genes = "*", 
  num_genes = sample(1:5, num_simulations, replace = TRUE, prob = 0.25 ^ (1:5)),
  multiplier = runif(num_simulations, 0, 1), 
  seed = sample.int(10 * num_simulations, num_simulations)
) {
  if (num_simulations == 0) {
    NULL
  } else {
    assert_that(
      (is.list(genes) && length(genes) == num_simulations) ||
        is.character(genes) || is.factor(genes)
    )
    # A non-list `genes` is the candidate set of every simulation, so wrap the
    # whole vector once per simulation. as.list(rep(genes, n)) would instead
    # split a multi-gene vector into single genes and give the wrong number
    # of rows.
    if (!is.list(genes)) genes <- rep(list(genes), num_simulations)
    
    tibble(
      type = "knockdown",
      timepoint,
      genes,
      num_genes,
      multiplier,
      seed
    )
  }
}